import torch
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt


def get_dataloader_workers():
    """Return the worker count handed to the data loaders.

    The value is passed as ``num_workers`` to the ``DataLoader`` instances
    built in this module. With 0, PyTorch loads batches in the main process
    instead of spawning worker subprocesses.
    """
    num_workers = 0
    return num_workers


def load_data_mnist(batch_size, root="../data"):
    """Download (if necessary) and load the MNIST dataset.

    Both splits are requested with ``download=True`` and use
    ``transforms.ToTensor()`` as the only transform.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory in which the dataset files are stored. Defaults to
            ``"../data"``, resolved relative to the current working
            directory.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
        The training loader is created with ``shuffle=True``; the test
        loader is created with ``shuffle=False``.
    """
    to_tensor = transforms.ToTensor()

    mnist_train = torchvision.datasets.MNIST(root=root,
                                             train=True,
                                             transform=to_tensor,
                                             download=True)
    mnist_test = torchvision.datasets.MNIST(root=root,
                                            train=False,
                                            transform=to_tensor,
                                            download=True)

    # Look the worker count up once so both loaders are configured alike.
    num_workers = get_dataloader_workers()

    train_loader = data.DataLoader(mnist_train,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers)
    test_loader = data.DataLoader(mnist_test,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers)
    return train_loader, test_loader
